import json
import os
import re
from os.path import expanduser

import numpy as np
import pandas as pd

# Current user's home directory, with backslashes replaced by forward slashes.
home = expanduser("~").replace("\\", "/")


def load_config() -> dict:
    """Load the configuration stored in ``<home>/.shares/config.json``.

    When the file does not exist, the ``<home>/.shares`` directory is created
    (if missing) and an empty dict is returned.

    Returns:
        The parsed JSON content of the config file, or ``{}`` when there is
        no config file yet.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    config_dir = os.path.join(home, ".shares")
    config_path = os.path.join(config_dir, "config.json")
    try:
        # Read with the same encoding store_config writes with, instead of
        # the platform default.
        with open(config_path, mode='r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        # No config yet: make sure the directory exists for a later store.
        os.makedirs(config_dir, exist_ok=True)
        return {}


def store_config(config: dict) -> None:
    """Write ``config`` as JSON to ``<home>/.shares/config.json``.

    The ``<home>/.shares`` directory is created when it does not exist, so
    this function does not depend on ``load_config`` having run first.

    Args:
        config: JSON-serialisable configuration mapping.

    Raises:
        TypeError: If ``config`` contains a value ``json`` cannot serialise.
            The existing config file is left untouched in that case.
    """
    # Serialise before opening the file: opening with 'w' truncates it, so a
    # serialisation error afterwards would destroy the previous config.
    payload = json.dumps(config)
    config_dir = os.path.join(home, ".shares")
    os.makedirs(config_dir, exist_ok=True)
    with open(os.path.join(config_dir, "config.json"), 'w', encoding='utf-8') as f:
        f.write(payload)


def auto_change_column_dtypes(df: pd.DataFrame, type_dict) -> pd.DataFrame:
    """Convert the columns of ``df`` to more specific dtypes.

    The frame is first passed through ``DataFrame.convert_dtypes``. Every
    column is then handled by the first rule that applies:

    1. The column is listed in ``type_dict`` with a type other than ``'str'``:
       if the type name contains ``"date"`` the column is parsed with
       ``pd.to_datetime`` (unparseable values become ``NaT``); otherwise empty
       strings are replaced by NaN and the column is cast with
       ``astype(..., errors='ignore')``.
    2. The column has a string/object dtype: the target type is guessed from
       its non-empty values (date, integer, float - see inline comments).
    3. The column has an ``Int`` dtype and its name contains ``"_date"``: the
       numbers are parsed as dates, with ``0`` treated as missing.

    Args:
        df: Frame to convert.
        type_dict: Mapping of column name to a dtype name. The values are
            tested with ``"date" in ...``, so they are expected to be strings.

    Returns:
        The frame produced by ``convert_dtypes`` with the converted columns.
    """
    df = df.convert_dtypes()
    for col in df.columns:
        # NOTE(review): columns declared as 'str' are not skipped; they fall
        # through to the content-based inference below - confirm intended.
        if col in type_dict and type_dict[col] != 'str':
            if "date" in type_dict[col]:
                df[col] = pd.to_datetime(df[col].replace('', np.nan).astype(str), errors="coerce")
            else:
                df[col] = df[col].replace('', np.nan).astype(type_dict[col], errors='ignore')
        elif 'string' in str(df[col].dtype) or 'object' in str(df[col].dtype):
            s_not_empty: pd.Series = df[col][df[col] != '']
            if len(s_not_empty) > 0:
                # The first non-empty value is the sample for the date and
                # float checks; the integer check looks at every value.
                str_value = s_not_empty.iloc[0]
                if isinstance(str_value, str) and (
                        re.match(r"^[12][0-9][0-9][0-9]-[01][0-9]-[0-3][0-9]$", str_value) or re.match(
                    r"^[12][0-9][0-9][0-9][01][0-9][0-3][0-9]$", str_value)):
                    # Sample looks like YYYY-MM-DD or YYYYMMDD.
                    df[col] = pd.to_datetime(df[col], errors='coerce')
                elif s_not_empty.str.match("(^[1-9][0-9]+$)|(^[0-9]$)").all() and not (
                        s_not_empty.str.len() == 6).all():
                    # All values are integers without leading zeros. Columns
                    # whose values are all 6 characters long are left as text
                    # (presumably security codes - TODO confirm).
                    df[col] = df[col].replace('', np.nan).astype('Int64', errors='ignore')
                elif isinstance(str_value, str) and re.match(r"^-?[0-9]+[.][0-9]+$", str_value):
                    # isinstance guard: object columns may hold non-string
                    # values, which re.match rejects with a TypeError.
                    df[col] = df[col].replace('', np.nan).astype(float, errors='ignore')
        elif 'Int' in str(df[col].dtype) and isinstance(col, str) and "_date" in col:
            # isinstance guard: column labels are not necessarily strings.
            df[col] = pd.to_datetime(df[col].replace(0, np.nan).astype(str, errors='ignore'), errors='coerce')

    return df


def nearest_trade_days(date_series, trade_days) -> pd.Series:
    """Map every date in ``date_series`` to a trading day.

    Each date is mapped to the first entry of ``trade_days`` that is on or
    after it (``np.searchsorted`` with ``side='left'``). Dates later than the
    last entry are mapped to that last entry. Missing values yield ``pd.NaT``.

    Args:
        date_series: Series of dates; values must be accepted by
            ``pd.to_datetime``.
        trade_days: Trading days in ascending order (required by
            ``np.searchsorted``). Must support indexing with an integer
            array, e.g. a ``DatetimeIndex`` or a NumPy array.

    Returns:
        A Series aligned with ``date_series`` holding the matched trading
        days.
    """

    def _day_key(value) -> str:
        # Dates are matched by calendar day so that the scalar type returned
        # by ``unique()`` and the one seen by ``apply`` do not need to agree.
        return pd.to_datetime(value).strftime("%Y-%m-%d")

    sorted_date = sorted(date_series.dropna().unique())
    trade_date_convert_dict = {}
    if sorted_date:
        # Skipped when there is nothing to look up: searching for an empty
        # list of values is not reliable across array types.
        loc_array = np.searchsorted(trade_days, sorted_date, side='left')
        # Dates past the end of the calendar are clamped to the last day.
        loc_array[loc_array >= len(trade_days)] = len(trade_days) - 1
        trade_date_convert_dict = {
            _day_key(day): trade_day
            for day, trade_day in zip(sorted_date, trade_days[loc_array])
        }
    return date_series.apply(
        lambda cell: pd.NaT if pd.isna(cell) else trade_date_convert_dict[_day_key(cell)])


def normalize_code(codes) -> list:
    """Normalize security codes to the ``<6 chars>.<SH|SZ|BJ>`` form.

    Recognised inputs:

    * already suffixed with ``.SZ`` / ``.SH`` / ``.BJ`` (tushare format):
      kept as is;
    * 8 characters with a ``sh`` / ``sz`` / ``bj`` prefix, e.g. ``sh600000``;
    * 9 characters with a ``sh.`` / ``sz.`` / ``bj.`` prefix, e.g.
      ``sh.600000``;
    * 6 characters: the suffix is chosen from the first character.

    Anything else is returned unchanged.

    Args:
        codes: Iterable of code strings.

    Returns:
        A list with one normalized code per input code, in input order.
    """
    suffix_by_prefix = {'sh': '.SH', 'sz': '.SZ', 'bj': '.BJ'}
    fix_codes = []
    for code in codes:
        prefix = code[:2]
        if code.endswith(('.SZ', '.SH', '.BJ')):  # tushare format
            fix_codes.append(code)
        elif len(code) == 8 and prefix in suffix_by_prefix:
            fix_codes.append(code[2:] + suffix_by_prefix[prefix])
        elif len(code) == 9 and prefix in suffix_by_prefix and code[2] == '.':
            # The dot is required: a 9-character code that merely starts with
            # the prefix is not this format and must not be rearranged.
            fix_codes.append(code[3:] + suffix_by_prefix[prefix])
        elif len(code) == 6:
            # NOTE(review): every leading character other than 0/1/3/4/8
            # (e.g. '2' or '9') maps to .SH - confirm against the exchanges'
            # numbering rules.
            if code[0] in ('0', '1', '3'):
                fix_codes.append(code + ".SZ")
            elif code[0] in ('8', '4'):
                fix_codes.append(code + ".BJ")
            else:
                fix_codes.append(code + ".SH")
        else:
            fix_codes.append(code)
    return fix_codes


# Names exported by a star-import of this module.
__all__ = ["load_config", 'store_config', 'auto_change_column_dtypes', 'nearest_trade_days', 'normalize_code']
